{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# File I/O\n",
    "\n",
    "So far we have discussed how to process data and how\n",
    "to build, train, and test deep learning models.\n",
    "However, at some point, we will hopefully be happy enough\n",
    "with the learned models that we will want\n",
    "to save the results for later use in various contexts\n",
    "(perhaps even to make predictions in deployment).\n",
    "Additionally, when running a long training process,\n",
    "it is best practice to periodically save intermediate results (checkpointing)\n",
    "so that we do not lose several days' worth of computation\n",
    "if we trip over the power cord of our server.\n",
    "Thus it is time we learned how to load and store\n",
    "both individual weight vectors and entire models.\n",
    "This section addresses both issues.\n",
    "\n",
    "## Loading and Saving Tensors\n",
    "\n",
    "For individual tensors, we can convert an `NDArray` into\n",
    "a `byte[]` by calling its `encode()` method.\n",
    "We can then convert the bytes back into an `NDArray` by calling\n",
    "the static `NDArray.decode()` method and passing in\n",
    "an `NDManager` (to manage the newly created `NDArray`) and the `byte[]` (the encoded tensor).\n",
    "\n",
    "We can then use `FileOutputStream` and `FileInputStream` to\n",
    "write these bytes to a file and read them back, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.8.0\n",
    "%maven ai.djl:basicdataset:0.8.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "\n",
    "%maven ai.djl.mxnet:mxnet-engine:0.8.0\n",
    "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.MalformedModelException;\n",
    "import ai.djl.Model;\n",
    "import ai.djl.inference.Predictor;\n",
    "import ai.djl.ndarray.NDArray;\n",
    "import ai.djl.ndarray.NDList;\n",
    "import ai.djl.ndarray.NDManager;\n",
    "import ai.djl.ndarray.types.DataType;\n",
    "import ai.djl.ndarray.types.Shape;\n",
    "import ai.djl.nn.Activation;\n",
    "import ai.djl.nn.Parameter;\n",
    "import ai.djl.nn.SequentialBlock;\n",
    "import ai.djl.nn.core.Linear;\n",
    "import ai.djl.training.initializer.XavierInitializer;\n",
    "import ai.djl.translate.NoopTranslator;\n",
    "import ai.djl.translate.TranslateException;\n",
    "import ai.djl.util.PairList;\n",
    "import ai.djl.util.Utils;\n",
    "\n",
    "import java.io.*;\n",
    "import java.nio.file.Files;\n",
    "import java.util.Arrays;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "NDArray x = manager.arange(4);\n",
    "try (FileOutputStream fos = new FileOutputStream(\"x-file\")) {\n",
    "    fos.write(x.encode());\n",
    "}\n",
    "x;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "We can now read this data from the stored file back into memory.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDArray x2;\n",
    "try (FileInputStream fis = new FileInputStream(\"x-file\")) {\n",
    "    // We use the `Utils` method `toByteArray()` to read \n",
    "    // from a `FileInputStream` and return it as a `byte[]`.\n",
    "    x2 = NDArray.decode(manager, Utils.toByteArray(fis));\n",
    "}\n",
    "x2;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "## Model Parameters\n",
    "\n",
    "Saving individual weight vectors (or other tensors) is useful,\n",
    "but it gets very tedious if we want to save\n",
    "(and later load) an entire model.\n",
    "After all, we might have hundreds of\n",
    "parameter groups sprinkled throughout.\n",
    "For this reason the framework provides built-in functionality\n",
    "to load and save entire networks.\n",
    "An important detail to note is that this\n",
    "saves model *parameters* and not the entire model.\n",
    "For example, if we have a 3-layer MLP,\n",
    "we need to specify the *architecture* separately.\n",
    "The reason for this is that the models themselves can contain arbitrary code,\n",
    "hence they cannot be serialized as naturally.\n",
    "Thus, in order to reinstate a model, we need\n",
    "to generate the architecture in code\n",
    "and then load the parameters from disk.\n",
    "Let us start with our familiar MLP.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 17,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Defines the MLP architecture in code; the parameters are stored separately\n",
    "public SequentialBlock createMLP() {\n",
    "    SequentialBlock mlp = new SequentialBlock();\n",
    "    mlp.add(Linear.builder().setUnits(256).build());\n",
    "    mlp.add(Activation.reluBlock());\n",
    "    mlp.add(Linear.builder().setUnits(10).build());\n",
    "    return mlp;\n",
    "}\n",
    "\n",
    "SequentialBlock original = createMLP();\n",
    "\n",
    "NDArray x = manager.randomUniform(0, 1, new Shape(2, 5));\n",
    "\n",
    "original.setInitializer(new XavierInitializer());\n",
    "original.initialize(manager, DataType.FLOAT32, x.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"mlp\");\n",
    "model.setBlock(original);\n",
    "\n",
    "// `NoopTranslator` passes `NDList`s through unchanged, so no cast is needed\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "\n",
    "NDArray y = predictor.predict(new NDList(x)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "Next, we store the parameters of the model as a file with the name `mlp.param`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 21,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Save the parameters; try-with-resources closes the stream when done\n",
    "File mlpParamFile = new File(\"mlp.param\");\n",
    "try (DataOutputStream os = new DataOutputStream(Files.newOutputStream(mlpParamFile.toPath()))) {\n",
    "    original.saveParameters(os);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 24
   },
   "source": [
    "To recover the model, we instantiate a clone\n",
    "of the original MLP model.\n",
    "Instead of randomly initializing the model parameters,\n",
    "we read the parameters stored in the file directly.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 25,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Create a duplicate of the network architecture\n",
    "SequentialBlock clone = createMLP();\n",
    "// Load the saved parameters; try-with-resources closes the stream when done\n",
    "try (DataInputStream is = new DataInputStream(Files.newInputStream(mlpParamFile.toPath()))) {\n",
    "    clone.loadParameters(manager, is);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let us directly compare the parameters of both models. For each index, we fetch the underlying array of the `Parameter` from both `PairList`s and compare the two.\n",
    "\n",
    "\n",
    "Note that we cannot compare the `Parameter` objects directly: when a `Parameter` is loaded, a new unique id is generated for it. Instead, we check that the underlying `NDArray`s are equal.\n",
    "\n",
    "They should be identical if the parameters were loaded properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Original model's parameters\n",
    "PairList<String, Parameter> originalParams = original.getParameters();\n",
    "// Loaded model's parameters\n",
    "PairList<String, Parameter> loadedParams = clone.getParameters();\n",
    "\n",
    "// Compare the underlying arrays pairwise, in order\n",
    "for (int i = 0; i < originalParams.size(); i++) {\n",
    "    boolean same = originalParams.valueAt(i).getArray().equals(loadedParams.valueAt(i).getArray());\n",
    "    System.out.print(same ? \"True \" : \"False \");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 28
   },
   "source": [
    "Since both instances have the same model parameters,\n",
    "they should produce the same output for the same input `x`.\n",
    "Let us verify this.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 29,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "Model modelClone = Model.newInstance(\"mlp\");\n",
    "modelClone.setBlock(clone);\n",
    "\n",
    "Predictor<NDList, NDList> predictor2 = modelClone.newPredictor(new NoopTranslator());\n",
    "\n",
    "NDArray yClone = predictor2.predict(new NDList(x)).singletonOrThrow();\n",
    "\n",
    "// Element-wise comparison of the two outputs\n",
    "y.eq(yClone);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 32
   },
   "source": [
    "## Summary\n",
    "\n",
    "* The `encode` and `decode` functions, together with `FileOutputStream` and `FileInputStream`, can be used to perform file I/O for tensor objects.\n",
    "* Saving a model's parameters does not save its architecture; the architecture has to be specified in code.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Even if there is no need to deploy trained models to a different device, what are the practical benefits of storing model parameters?\n",
    "1. Assume that we want to reuse only parts of a network to be incorporated into a network of a *different* architecture. How would you go about using, say, the first two layers from a previous network in a new network?\n",
    "1. How would you go about saving network architecture and parameters? What restrictions would you impose on the architecture?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "11.0.5+10-LTS"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
